{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ur8xi4C7S06n"
      },
      "outputs": [],
      "source": [
        "# Copyright 2025 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "84f0f73a0f76"
      },
      "source": [
        "# Handling large-scale embedding generation for Vertex AI Vector Search\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/embeddings/large-embs-generation-for-vvs.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fembeddings%2Flarge-embs-generation-for-vvs.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/embeddings/large-embs-generation-for-vvs.ipynb\">\n",
        "      <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/embeddings/large-embs-generation-for-vvs.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://raw.githubusercontent.com/primer/octicons/refs/heads/main/icons/mark-github-24.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</table>\n",
        "\n",
        "<div style=\"clear: both;\"></div>\n",
        "\n",
        "<b>Share to:</b>\n",
        "\n",
        "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/embeddings/large-embs-generation-for-vvs.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/embeddings/large-embs-generation-for-vvs.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/embeddings/large-embs-generation-for-vvs.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/embeddings/large-embs-generation-for-vvs.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/embeddings/large-embs-generation-for-vvs.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
        "</a>\n",
        "\n",
        "| | |\n",
        "|-|-|\n",
        "| Author(s) |  [Kaz Sato](https://github.com/kazunori279/) |"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tvgnzT1CKxrO"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This tutorial provides a simple example of how to efficiently generate text and multimodal embeddings for millions of items on a notebook, using [Vertex AI Embeddings API](https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings) and [Vector Search](https://cloud.google.com/vertex-ai/docs/vector-search/overview).\n",
        "\n",
        "Embeddings API provides [batch predictions](https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/batch-prediction-genai-embeddings#generative-ai-batch-embedding-python_vertex_ai_sdk) for generating text embeddings for large datasets, but there are limitations such as:\n",
        "\n",
        "1. some models including multimodal are not supported\n",
        "2. you cannot specify [task types](https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types) for the text embeddings.\n",
        "\n",
        "For those cases, you may find this notebook useful.\n",
        "\n",
        "### Things to consider\n",
        "\n",
        "When processing a large dataset to generate embeddings, you'll need code that does:\n",
        "\n",
        "1. **Rate Limiting**: Control the frequency of your API calls to stay within your usage limits. This prevents exceeding quotas and ensures continuous processing for the large dataset.\n",
        "1. **Multithreading**: Process data and make API calls concurrently. This maximizes your quota usage and reduces the impact of latency, speeding up the overall process.\n",
        "1. **Checkpointing**: To avoid starting from scratch if an error occurs, Save the progress periodically. This lets you resume from the last saved point.\n",
        "\n",
        "This tutorial provides a straightforward example of how to implement these techniques effectively. You can choose the size of sample dataset from 10K, 100K or 1M items. With the default quota limit of Embedding API (as of Dec 2024), this example can generate:\n",
        "\n",
        "- 10K text embeddings in a few minutes\n",
        "- 100K text embeddings in roughly 5 minutes\n",
        "- 1M text embeddings in roughly one hour\n",
        "\n",
        "**Note**: This approach is intended for development purposes only. For production environments, consider building a robust MLOps pipeline with [Vertex AI Pipelines](https://cloud.google.com/vertex-ai/docs/pipelines/introduction) and [Dataflow](https://cloud.google.com/products/dataflow) to manage your embedding generation workflow."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W_PpmLpTUd_O"
      },
      "source": [
        "## Understanding quota for Embeddings API\n",
        "\n",
        "When using any Vertex AI API extensively, it's essential to ensure you have enough quota for your API calls. For the Embeddings API, the following pages detail the relevant quotas (Quota values listed below are current for us-central1 as of December 2024, but each region has different values and may change at any time. Please check the individual pages for the most up-to-date information).\n",
        "\n",
        "### [Quotas by region and by model](https://cloud.google.com/vertex-ai/generative-ai/docs/quotas#view-the-quotas-by-region-and-by-model)\n",
        "\n",
        "- base_model: text-embedding-005\t**1,500 requests per minute**\t(region: us-central1)\n",
        "- base_model: multimodalembedding\t**120 requests per minute** (region: us-central1)\n",
        "\n",
        "These are quotas that limits the number of API calls you can make. If you exceed this limit, you'll encounter ResourceExhausted errors. We will cover how to implement \"throttling\" in a later section to prevent this from happening. This will involve using code to control the rate of your API calls and stay within the quota.\n",
        "\n",
        "In many cases, you might need to request a quota increase to generate a large number of embeddings in a timely manner.  See the [View and manage quotas](https://cloud.google.com/docs/quotas/view-manage) page for instructions on how to check your current quota values and make increase requests.\n",
        "\n",
        "### [Text embedding limits](https://cloud.google.com/vertex-ai/generative-ai/docs/quotas#text-embedding-limits)\n",
        "\n",
        "- Each text embedding model request can have up to **250 input texts**\n",
        "- **20,000 tokens per request**\n",
        "\n",
        "This is the limits on the number of texts (or tokens of the texts) you can send with a single API call. Sending too many at once, such as the maximums shown above, can cause processing delays and make it difficult to monitor system health. We'll discuss strategies for finding the right balance later on."
      ]
    },
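    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To get a feel for how these quotas translate into run time, the arithmetic can be sketched as follows. This is a minimal sketch, assuming a simple throttling scheme that sends one request every `60 / requests_per_minute` seconds. The helper names `request_interval_seconds()` and `estimate_minutes()` are hypothetical, and the estimate ignores API latency, retries, and errors.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def request_interval_seconds(requests_per_minute: int) -> float:\n",
        "    # Seconds to wait between requests to stay within a per-minute quota.\n",
        "    return 60.0 / requests_per_minute\n",
        "\n",
        "\n",
        "def estimate_minutes(\n",
        "    num_items: int, requests_per_minute: int, items_per_request: int\n",
        ") -> float:\n",
        "    # Lower-bound run time in minutes (hypothetical helper, throttling only).\n",
        "    num_requests = math.ceil(num_items / items_per_request)\n",
        "    return num_requests * request_interval_seconds(requests_per_minute) / 60.0\n",
        "\n",
        "\n",
        "# Assumed values: the us-central1 text embedding quota and 20 items per request.\n",
        "print(request_interval_seconds(1500))\n",
        "print(estimate_minutes(1_000_000, 1500, 20))\n",
        "```\n",
        "\n",
        "With 1,500 requests per minute and 20 items per request, 1M items need 50,000 requests, or roughly 33 minutes of throttled request time before latency and errors are taken into account."
      ]
    },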
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "61RBz8LLbxCR"
      },
      "source": [
        "## Get started\n",
        "\n",
        "Let's get started. In this section, we'll set up the necessary libraries and environment variables for this tutorial."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "No17Cw5hgx12"
      },
      "source": [
        "### Install Vertex AI SDK and other required packages\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gha9AHphiR-M"
      },
      "source": [
        "This line installs the Google Cloud AI Platform library, which is necessary for interacting with Vertex AI services, including the Embeddings API."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tFy3H3aPgx12"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade --user --quiet google-cloud-aiplatform"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R5Xep4W9lq-Z"
      },
      "source": [
        "### Restart runtime\n",
        "\n",
        "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.\n",
        "\n",
        "The restart might take a minute or longer. After it's restarted, continue to the next step."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "XRvKdaPDTznN"
      },
      "outputs": [],
      "source": [
        "import IPython\n",
        "\n",
        "app = IPython.Application.instance()\n",
        "app.kernel.do_shutdown(True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SbmM4z7FOBpM"
      },
      "source": [
        "<div class=\"alert alert-block alert-warning\">\n",
        "<b>⚠️ The kernel is going to restart. In Colab or Colab Enterprise, you might see an error message that says \"Your session crashed for an unknown reason.\" This is expected. Wait until it's finished before continuing to the next step. ⚠️</b>\n",
        "</div>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dmWOrTJ3gx13"
      },
      "source": [
        "### Authenticate your notebook environment (Colab only)\n",
        "\n",
        "If you're running this notebook on Google Colab, run the cell below to authenticate your environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NyKGtVQjgx13"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "if \"google.colab\" in sys.modules:\n",
        "    from google.colab import auth\n",
        "\n",
        "    auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DF4l8DTdWgPY"
      },
      "source": [
        "### Set Google Cloud project information and initialize Vertex AI SDK\n",
        "\n",
        "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com). Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).\n",
        "\n",
        "The following code defines project configuration variables. If you're running this notebook on Google Colab, you'll need to enter the `PROJECT_ID` manually. However, on Colab Enterprise and Vertex AI Workbench, it will be detected automatically. The code also generates a unique session ID based on the current timestamp."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "Nqwi-5ufWp_B"
      },
      "outputs": [],
      "source": [
        "# Use the environment variable if the user doesn't provide Project ID.\n",
        "import os\n",
        "\n",
        "# fmt: off\n",
        "PROJECT_ID = \"[your-project-id]\"  # @param {type: \"string\", placeholder: \"[your-project-id]\", isTemplate: true}\n",
        "# fmt: on\n",
        "if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\":\n",
        "    PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
        "LOCATION = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")\n",
        "\n",
        "# generate an unique id for this session\n",
        "from datetime import datetime\n",
        "\n",
        "UID = datetime.now().strftime(\"%m%d%H%M%S\")\n",
        "f\"Unique ID for this session is: {UID}\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EdvJRUWRNGHE"
      },
      "source": [
        "## Prepare dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5303c05f7aa6"
      },
      "source": [
        "### Download GBIF dataset\n",
        "\n",
        "The following code downloads a [GBIF (Global Biodiversity Information Facility)](https://www.gbif.org/) dataset from Google Cloud Storage as a sample dataset that contains animal photos with their name and description. You can choose between 10k, 100k, or 1M items. It then uncompresses the downloaded gzip file to make it usable."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "53a72916-d6ad-4aaf-adc2-cbad67b4887c"
      },
      "outputs": [],
      "source": [
        "# choose dataset\n",
        "FILE_NAME = \"gbif_10k.json\"\n",
        "# FILE_NAME = \"gbif_100k.json\"\n",
        "# FILE_NAME = \"gbif_1m.json\"\n",
        "\n",
        "# download dataset\n",
        "!rm -rf ./gbif_*.json.gz\n",
        "!wget https://storage.googleapis.com/gcp-samples-ic0-ac/datasets/{FILE_NAME}.gz\n",
        "\n",
        "# uncompress the gzip file\n",
        "!rm -rf ./gbif_*.json\n",
        "!gzip -d ./gbif_*.json.gz"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bV7xDxu1lt92"
      },
      "source": [
        "This code then loads the json file into a list. The first three items of the list are then printed. We will use `name` and `description` for generating text embeddings, and `gcsUri` for generating multimodal embeddings for the animal photo."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "SUhYN907pEho"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "\n",
        "with open(FILE_NAME) as f:\n",
        "    items = [json.loads(line) for line in f]\n",
        "items[:3]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "c_pGU33-myJJ"
      },
      "outputs": [],
      "source": [
        "# print the animal image of the first item\n",
        "from IPython.display import Image\n",
        "\n",
        "print(f\"Name: {items[0]['name']}\")\n",
        "print(f\"Description: {items[0]['description']}\")\n",
        "Image(url=items[0][\"url\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TEEq5WuB0sUl"
      },
      "source": [
        "## Generate text embeddings\n",
        "\n",
        "In this section, we will generate text embeddings for the sample dataset.\n",
        "\n",
        "### How to use Embedding API for text\n",
        "\n",
        "This code utilizes the Embeddings API to obtain [an embedding model](https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings#supported-models). It defines a function `generate_text_embeddings()` that accepts a list of items and generates their text embeddings using the model, with a specified [task type](https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types) and dimensionality."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "nqcrrbh3365v"
      },
      "outputs": [],
      "source": [
        "import vertexai\n",
        "from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel\n",
        "\n",
        "TEXT_EMB_MODEL_NAME = \"text-embedding-005\"\n",
        "TEXT_EMB_TASK_TYPE = \"SEMANTIC_SIMILARITY\"\n",
        "TEXT_EMB_DIMENSIONALITY = 768\n",
        "\n",
        "vertexai.init(project=PROJECT_ID, location=LOCATION)\n",
        "text_emb_model = TextEmbeddingModel.from_pretrained(TEXT_EMB_MODEL_NAME)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "bj5Big704X_I"
      },
      "outputs": [],
      "source": [
        "from collections.abc import Callable\n",
        "from typing import Any\n",
        "\n",
        "def generate_text_embeddings(items: list[dict[str, Any]]) -> list[list[float]]:\n",
        "    \"\"\"Generate text embeddings for items.\"\"\"\n",
        "    # Combine name and description for embedding input.\n",
        "    names: list[str] = [item[\"name\"] + \" \" + item[\"description\"] for item in items]\n",
        "\n",
        "    # Prepare inputs for the text embedding model.\n",
        "    inputs: list[TextEmbeddingInput] = [\n",
        "        TextEmbeddingInput(name, TEXT_EMB_TASK_TYPE) for name in names\n",
        "    ]\n",
        "    kwargs = {\"output_dimensionality\": TEXT_EMB_DIMENSIONALITY}\n",
        "\n",
        "    # Get embeddings from the model.\n",
        "    return [emb.values for emb in text_emb_model.get_embeddings(inputs, **kwargs)]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "Ohmqa5TdVjkC"
      },
      "outputs": [],
      "source": [
        "# test it\n",
        "test_items = items[:1]\n",
        "print(generate_text_embeddings(test_items))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dprKR-piYddZ"
      },
      "source": [
        "### Define a worker thread\n",
        "\n",
        "Instead of generating embeddings for each item individually, we'll use multiple worker threads per API call to maximize quota utilization and make the entire job robust to errors. This code defines `run_worker_thread()`, which wraps `generate_text_embeddings()` to run as a worker thread, adding generated embeddings (or errors) into a queue."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "hChtZdfUYY44"
      },
      "outputs": [],
      "source": [
        "import queue\n",
        "\n",
        "def run_worker_thread(\n",
        "    generation_func: Callable[[list[dict[str, Any]], \"queue.Queue\"], None],\n",
        "    items: list[dict[str, Any]],\n",
        "    emb_queue: \"queue.Queue\",\n",
        "    err_queue: \"queue.Queue\",\n",
        ") -> None:\n",
        "    \"\"\"Runs a worker thread that generates embeddings with a single API call and handles\n",
        "    potential errors.\n",
        "\n",
        "    Args:\n",
        "        generation_func: A function that takes a list of items and returns their embeddings.\n",
        "        items: The list of items to process. Each item should be a dictionary.\n",
        "        emb_queue: The queue to put the generated embeddings into.\n",
        "        err_queue: The queue to put any encountered errors into.\n",
        "    \"\"\"\n",
        "    try:\n",
        "        embs = generation_func(items)\n",
        "        for i in range(len(items)):\n",
        "            emb_queue.put({\"id\": items[i][\"id\"], \"embedding\": embs[i]})\n",
        "    except Exception as e:\n",
        "        err_queue.put(str(e))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XvDX3FYi1GqZ"
      },
      "source": [
        "### Prepare a Cloud Storage bucket\n",
        "\n",
        "We will store the generated embeddings to a Cloud Storage bucket and folder created by this code."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "PyMhU6Vs1K5J"
      },
      "outputs": [],
      "source": [
        "from google.cloud import storage\n",
        "\n",
        "GCS_BUCKET = f\"{PROJECT_ID}-embs-{UID}\"\n",
        "GCS_TEXT_EMB_PATH = \"text_embs\"\n",
        "\n",
        "# create a bucket\n",
        "storage_client = storage.Client()\n",
        "storage_bucket = storage_client.bucket(GCS_BUCKET)\n",
        "storage_bucket = storage_client.create_bucket(storage_bucket, location=\"us-central1\")\n",
        "\n",
        "# create a folder for storing text embeddings\n",
        "empty_blob = storage_bucket.blob(GCS_TEXT_EMB_PATH + \"/\")\n",
        "empty_blob.upload_from_string(\"\")\n",
        "print(f\"\\nCreated text embedding folder: gs://{GCS_BUCKET}/{GCS_TEXT_EMB_PATH}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eDI7lYcZ4_Nv"
      },
      "source": [
        "### Prepare a queue manager thread\n",
        "\n",
        "This code defines a queue manager, `run_queue_manager_thread()`, that receives embeddings from the worker threads and handles potential errors. The manager periodically flushes successful embeddings to Cloud Storage using `flush_emb_queue()` and errors to a log file using `flush_err_queue()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "2k6cV2cVSyZ-"
      },
      "outputs": [],
      "source": [
        "QUEUE_FLUSH_THRESHOLD = 10000\n",
        "ERR_FILE_NAME = f\"err_{UID}.log\"\n",
        "\n",
        "is_queue_manager_enabled: bool = True\n",
        "\n",
        "\n",
        "def flush_emb_queue(emb_queue: queue.Queue, gcs_path: str, count: int) -> None:\n",
        "    \"\"\"Flushes the embedding queue to Cloud Storage.\n",
        "\n",
        "    Args:\n",
        "    emb_queue: The queue containing embedding dictionaries.\n",
        "    gcs_path: The destination path in Cloud Storage.\n",
        "    count: The number of embeddings to flush from the queue.\n",
        "    \"\"\"\n",
        "    timestamp: str = datetime.now().strftime(\"%m%d%H%M-%S%f\")\n",
        "    embs: str = \"\"\n",
        "    for _ in range(count):\n",
        "        emb: dict = emb_queue.get()\n",
        "        embs += json.dumps(emb) + \"\\n\"\n",
        "    gcs_file = storage_bucket.blob(f\"{gcs_path}/{timestamp}_embs.json\")\n",
        "    gcs_file.upload_from_string(embs, content_type=\"application/json\")\n",
        "    print(f\"Uploaded {count} embeddings to {gcs_file.name}\")\n",
        "\n",
        "\n",
        "def flush_err_queue(err_queue: queue.Queue) -> None:\n",
        "    \"\"\"Flushes the error queue to the error log file.\"\"\"\n",
        "    with open(ERR_FILE_NAME, \"a\") as file:\n",
        "        while not err_queue.empty():\n",
        "            error = err_queue.get()\n",
        "            file.write(f\"Error {error}\\n\")\n",
        "            print(f\"Error {error}\")\n",
        "\n",
        "\n",
        "def run_queue_manager_thread(\n",
        "    emb_queue: queue.Queue, err_queue: queue.Queue, gcs_path: str\n",
        ") -> None:\n",
        "    \"\"\"Runs the queue manager thread, which monitors and flushes the embedding and error queues.\n",
        "\n",
        "    Args:\n",
        "        emb_queue: The queue for storing embeddings.\n",
        "        err_queue: The queue for storing errors.\n",
        "        gcs_path: The path to Cloud Storage where embeddings will be stored.\n",
        "    \"\"\"\n",
        "    # Continue managing the queues while enabled\n",
        "    global is_queue_manager_enabled\n",
        "    while is_queue_manager_enabled:\n",
        "        time.sleep(0.1)\n",
        "\n",
        "        # Flush the embedding queue if it exceeds the threshold\n",
        "        if emb_queue.qsize() > QUEUE_FLUSH_THRESHOLD:\n",
        "            flush_emb_queue(emb_queue, gcs_path, QUEUE_FLUSH_THRESHOLD)\n",
        "\n",
        "        # Flush the error queue if it contains any errors\n",
        "        if err_queue.qsize() > 0:\n",
        "            flush_err_queue(err_queue)\n",
        "\n",
        "    # Perform a final flush of both queues when the queue manager is disabled\n",
        "    flush_emb_queue(emb_queue, gcs_path, emb_queue.qsize())\n",
        "    flush_err_queue(err_queue)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4RsMKj_hHJhS"
      },
      "source": [
        "### Define the main loop\n",
        "\n",
        "`generate_embeddings()` is the main loop that performs the following: 1) initializes and starts the queue manager thread, 2) runs a loop that spawns a new worker thread for each API call, adhering to the quota limit, and 3) ensures all embeddings and any errors are processed from the queue."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "f98aab72-0540-4b05-a1ee-54989a744baa"
      },
      "outputs": [],
      "source": [
        "import threading\n",
        "import time\n",
        "\n",
        "from tqdm.notebook import tqdm\n",
        "\n",
        "def generate_embeddings(\n",
        "    generation_func: Callable[[list[dict[str, Any]], \"queue.Queue\"], None],\n",
        "    reqs_per_min_quota: int,\n",
        "    items_per_req: int,\n",
        "    gcs_path: str,\n",
        "    items: list[dict[str, Any]],\n",
        ") -> None:\n",
        "    \"\"\"Generates embeddings with throttling and error handling.\n",
        "\n",
        "    Args:\n",
        "        generation_func: The function used to generate the embeddings.\n",
        "        reqs_per_min_quota: The maximum number of requests allowed per minute for the model.\n",
        "        items_per_req: The number of items to include in each request.\n",
        "        items: The dataset we'll be working with.\n",
        "    \"\"\"\n",
        "    # All threads.\n",
        "    threads: list[threading.Thread] = []\n",
        "\n",
        "    # Throttling interval.\n",
        "    req_interval: float = 1.0 / (reqs_per_min_quota / 60)\n",
        "\n",
        "    # Start queue manager thread.\n",
        "    global is_queue_manager_enabled\n",
        "    is_queue_manager_enabled = True\n",
        "    emb_queue: queue.Queue = queue.Queue()\n",
        "    err_queue: queue.Queue = queue.Queue()\n",
        "    queue_manager_thread: threading.Thread = threading.Thread(\n",
        "        target=run_queue_manager_thread, args=(emb_queue, err_queue, gcs_path)\n",
        "    )\n",
        "    queue_manager_thread.start()\n",
        "\n",
        "    # Generate embeddings.\n",
        "    for i in tqdm(range(0, len(items), items_per_req)):\n",
        "        # Throttle requests.\n",
        "        time.sleep(req_interval)\n",
        "\n",
        "        # Start a worker thread.\n",
        "        items_slice: list[dict[str, Any]] = items[i : i + items_per_req]\n",
        "        worker_thread: threading.Thread = threading.Thread(\n",
        "            target=run_worker_thread,\n",
        "            args=(generation_func, items_slice, emb_queue, err_queue),\n",
        "        )\n",
        "        worker_thread.start()\n",
        "        threads.append(worker_thread)\n",
        "\n",
        "    # Wait for all worker threads to finish.\n",
        "    print(f\"Waiting for {len(threads)} threads to finish...\")\n",
        "    for i in tqdm(range(0, len(threads), 1)):\n",
        "        threads[i].join()\n",
        "\n",
        "    # Wait for the queue manager to stop.\n",
        "    print(\"Waiting for the queue manager to finish...\")\n",
        "    is_queue_manager_enabled = False\n",
        "    queue_manager_thread.join()\n",
        "\n",
        "    # Print error count.\n",
        "    if os.path.exists(ERR_FILE_NAME):\n",
        "        with open(ERR_FILE_NAME) as f:\n",
        "            error_count: int = len(f.readlines())\n",
        "        print(f\"{error_count} errors recorded in {ERR_FILE_NAME}\")\n",
        "    else:\n",
        "        print(\"No errors recorded\")\n",
        "    print(\"Done!\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4gBAj_Evx_JM"
      },
      "source": [
        "### Run it\n",
        "\n",
        "Okay, let's run this! We'll use the following parameters:\n",
        "\n",
        "- `generation_func`: The function used to generate the embeddings.\n",
        "- `reqs_per_min_quota`: The maximum number of requests allowed per minute for the model.\n",
        "- `items_per_req`: The number of items to include in each request.\n",
        "- `items`: The dataset we'll be working with.\n",
        "\n",
        "Regarding `items_per_req`, the text embedding model allows a maximum of 250 items per request, as mentioned earlier. However, setting it to the maximum value slows down the entire throughput. You might want to choose a more moderate value, like 20, for better performance."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "KqT4-hoTBfTp"
      },
      "outputs": [],
      "source": [
        "# Start generating text embeddings\n",
        "generate_embeddings(\n",
        "    generation_func=generate_text_embeddings,\n",
        "    reqs_per_min_quota=1500,\n",
        "    items_per_req=20,\n",
        "    gcs_path=GCS_TEXT_EMB_PATH,\n",
        "    items=items,\n",
        ")\n",
        "\n",
        "print(f\"\\nCreated text embeddings on folder: gs://{GCS_BUCKET}/{GCS_TEXT_EMB_PATH}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "306cc4861fc7"
      },
      "source": [
        "### Tips for a long running job\n",
        "\n",
        "If you run the code for a long time, you might occasionally lose the connection between your browser and the runtime, making it difficult to track the job's progress. For these long-running jobs, you should convert this notebook to a `.py` file and run it as a Python process using the `nohup` command."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ooNs0butatqs"
      },
      "source": [
        "### Monitoring errors\n",
        "\n",
        "While embedding generation is in progress, you can monitor the error log for any recorded errors.\n",
        "\n",
        "### Monitoring quota usage\n",
        "\n",
        "You can also monitor your quota usage in the Google Cloud console. To do this, follow the instructions under [View quotas in the Google Cloud console](https://cloud.google.com/docs/quotas/view-manage#viewing_your_quota_console). Use the filter and type \"embedding\" to find the relevant quota, then click the Show usage chart button on the right to see your usage details.\n",
        "\n",
        "Typically, you'll find that this tool utilizes around 80-90% of the specified quota limit.\n",
        "\n",
        "<br/>\n",
        "<center><img src=\"https://storage.googleapis.com/gcp-samples-ic0-ac/images/quota_usage.png\" width=\"400\"/></center>\n",
        "<br/>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "djBMLcin2XoN"
      },
      "source": [
        "### (Optional) Creating a Vector Search index\n",
        "\n",
        "Optionally, you can try using the generated embeddings for building [Vertex AI Vector Search](https://cloud.google.com/vertex-ai/docs/vector-search/overview) index for semantic search.\n",
        "\n",
        "Check out the JSONL files stored on the Cloud Storage folder by using [Cloud Storage console](https://console.cloud.google.com/storage/browser). The file looks like:\n",
        "\n",
        "```\n",
        "{\n",
        "    \"id\": \"2251880622\",\n",
        "    \"embedding\": [-0.005421683192253113, 0.030090663582086563,...\n",
        "}\n",
        "...\n",
        "```\n",
        "\n",
        "This format conforms the [input data format of Vertex AI Vector Search](https://cloud.google.com/vertex-ai/docs/vector-search/setup/format-structure). You can directly use the files to create an index by either:\n",
        "\n",
        "- Using [Vector Search console](https://cloud.google.com/vertex-ai/docs/vector-search/create-manage-index#create_index-console): Specify the Cloud Storage path `gs://...` on the create new index dialog.\n",
        "- Using [Vector Search Python SDK](https://cloud.google.com/vertex-ai/docs/vector-search/create-manage-index#create_index-python_vertex_ai_sdk): Specify the Cloud Storage path `gs://...` as `contents_delta_uri` parameter of `create_tree_ah_index()`.\n",
        "\n",
        "For further usage of Vector Search, refer to [the product documentation](https://cloud.google.com/vertex-ai/docs/vector-search/overview)."
      ]
    },
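    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before creating an index, it can help to sanity-check a few records. The sketch below parses one JSONL line and checks that it has the `id` and `embedding` fields shown above, with the expected number of dimensions. `check_record` is a hypothetical helper, and the sample record uses a tiny 3-dimensional embedding purely for illustration; pass the dimensionality configured for your embedding model.\n",
        "\n",
        "```python\n",
        "import json\n",
        "\n",
        "\n",
        "def check_record(line: str, dimensions: int) -> bool:\n",
        "    \"\"\"Return True if a JSONL line has a string `id` and an `embedding` of the expected size.\"\"\"\n",
        "    record = json.loads(line)\n",
        "    embedding = record.get(\"embedding\")\n",
        "    return (\n",
        "        isinstance(record.get(\"id\"), str)\n",
        "        and isinstance(embedding, list)\n",
        "        and len(embedding) == dimensions\n",
        "        and all(isinstance(v, (int, float)) for v in embedding)\n",
        "    )\n",
        "\n",
        "\n",
        "# Illustrative record; real embeddings have many more dimensions.\n",
        "sample = '{\"id\": \"2251880622\", \"embedding\": [-0.0054, 0.0301, 0.0123]}'\n",
        "print(check_record(sample, dimensions=3))\n",
        "```\n",
        "\n",
        "A check like this catches empty or truncated embeddings early, before they are passed to index creation."
      ]
    },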
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qV7AVEqh0Qnz"
      },
      "source": [
        "## Generate Multimodal Embeddings\n",
        "\n",
        "The tool above can be reused to generate other types of embeddings by defining a new `generation_func`. The following code demonstrates how to generate multimodal embeddings for the animal photos."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "Bb4lI3x1CrlP"
      },
      "outputs": [],
      "source": [
        "# Create a folder for storing multimodal embeddings\n",
        "GCS_MM_EMB_PATH = \"mm_embs\"\n",
        "empty_blob = storage_bucket.blob(GCS_MM_EMB_PATH + \"/\")\n",
        "empty_blob.upload_from_string(\"\")\n",
        "print(f\"\\nCreated multimodal embedding folder: gs://{GCS_BUCKET}/{GCS_MM_EMB_PATH}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "Ijlo_3b43KHk"
      },
      "outputs": [],
      "source": [
        "# Prepare a multimodal embedding model\n",
        "from vertexai.vision_models import (\n",
        "    Image,\n",
        "    MultiModalEmbeddingModel,\n",
        "    MultiModalEmbeddingResponse,\n",
        ")\n",
        "\n",
        "MM_EMB_MODEL_NAME = \"multimodalembedding\"\n",
        "MM_EMB_DIMENSIONALITY = 1408\n",
        "\n",
        "mm_emb_model = MultiModalEmbeddingModel.from_pretrained(MM_EMB_MODEL_NAME)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "eawEVliy-f6E"
      },
      "outputs": [],
      "source": [
        "def generate_mm_embeddings(items: list[dict[str, Any]]) -> list[list[float]]:\n",
        "    \"\"\"Generate multimodal embeddings for items.\"\"\"\n",
        "    # Extract GCS URIs from items.\n",
        "    gcsUris: list[str] = [item[\"gcsUri\"] for item in items]\n",
        "\n",
        "    # Load images from the GCS URIs.\n",
        "    images: list[Image] = [Image.load_from_file(gcsUri) for gcsUri in gcsUris]\n",
        "\n",
        "    # Get multimodal embeddings.\n",
        "    embs: list[MultiModalEmbeddingResponse] = [\n",
        "        mm_emb_model.get_embeddings(image=image, dimension=MM_EMB_DIMENSIONALITY)\n",
        "        for image in images\n",
        "    ]\n",
        "\n",
        "    # Return the image embeddings.\n",
        "    return [emb.image_embedding for emb in embs]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "5erkBoCeEFqW"
      },
      "outputs": [],
      "source": [
        "# Run it\n",
        "generate_embeddings(\n",
        "    generation_func=generate_mm_embeddings,\n",
        "    reqs_per_min_quota=120,\n",
        "    items_per_req=1,\n",
        "    gcs_path=GCS_MM_EMB_PATH,\n",
        "    items=items,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TsLypcRV8alx"
      },
      "source": [
        "For generating multimodal embeddings, there is a couple of difference from the text embeddings:\n",
        "\n",
        "- `reqs_per_min_quota`: The default quota limit is lower than text embedding mode.\n",
        "- `items_per_req`: The multimodal model can only take single image per request.\n",
        "\n",
        "With these differences, it would take much longer time to generate embeddings for all items."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2a4e033321ad"
      },
      "source": [
        "## Cleaning up\n",
        "\n",
        "That was it. Finally, remove the Cloud Storage bucket that use used for this tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AVG9aMs8eeTt"
      },
      "outputs": [],
      "source": [
        "storage_bucket.delete(force=True)\n",
        "print(\"Storage bucket deleted\")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "large-embs-generation-for-vvs.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
